#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project: Mestizajes

Implements the core embedding construction and pooling workflow: it encodes each
unique link once into a vector representation, computes TF-IDF weights per group,
and aggregates link-level embeddings into century- and person-level
representations. The module also provides utilities for normalization,
efficient mean pooling, and downstream dispersion analysis across centuries.
'''

import math
import re
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer


# Helper functions

def canon(s: str) -> str:
    """Return ``s`` as a string with outer whitespace stripped and every run of
    whitespace and/or underscores collapsed into a single space.

    Stripping happens before collapsing, so leading/trailing underscores are
    turned into a single space rather than removed.
    """
    stripped = str(s).strip()
    collapsed = re.sub(r"[\s_]+", " ", stripped)
    return collapsed


def l2_rows(M: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale every row of ``M`` to unit Euclidean length and return float32.

    Row norms smaller than ``eps`` are clipped up to ``eps`` before dividing,
    so all-zero rows stay zero instead of turning into NaNs.
    """
    row_norms = np.linalg.norm(M, axis=1, keepdims=True)
    safe_norms = row_norms.clip(min=eps)
    unit_rows = M / safe_norms
    return unit_rows.astype(np.float32)


import numpy as np

def mean_pool(M, idxs, w=None, chunk_size: int = 200_000):
    """
    Compute the (weighted) mean of the rows of ``M`` selected by ``idxs``.

    ``M[idxs]`` is never materialized. Repeated indices are first collapsed
    into (unique index, total weight) pairs, which gives the exact same mean.
    The unique rows are then gathered ``chunk_size`` at a time, so extra
    memory stays around ``chunk_size * d`` float64 values.

    Parameters
    ----------
    M : ndarray of shape (n_vocab, d)
        Row matrix to pool from.
    idxs : list or 1-D array of int
        Row indices, one per occurrence; may repeat and may be very long.
    w : list or 1-D array of float, optional
        Per-occurrence weights, same length as ``idxs``. When None every
        occurrence counts once (plain mean).
    chunk_size : int
        Number of unique rows gathered per step. Increase it for speed and
        decrease it to save memory. Must be >= 1.

    Returns
    -------
    ndarray of shape (d,), float32
        The pooled vector. It is all zeros when ``idxs`` is empty or the
        weights sum to zero.

    Raises
    ------
    ValueError
        If ``w`` and ``idxs`` differ in length, or ``chunk_size < 1``.
    """
    idxs = np.asarray(idxs, dtype=np.int64)
    d = M.shape[1]

    # Check .size rather than truthiness. `not array` raises for arrays with
    # more than one element, and it treats np.array([0]) as empty.
    if idxs.size == 0:
        return np.zeros(d, dtype=np.float32)
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")

    if w is None:
        # Unweighted: each unique row's weight is its occurrence count.
        uniq, counts = np.unique(idxs, return_counts=True)
        w_sum = counts.astype(np.float64)
    else:
        # Keep weights in float64 end to end (no float32 round-trip).
        w = np.asarray(w, dtype=np.float64)
        if w.shape != idxs.shape:
            raise ValueError("Weights length must match idxs length")
        # Sum the weights of every occurrence that maps to the same unique row.
        uniq, inv = np.unique(idxs, return_inverse=True)
        w_sum = np.bincount(inv.ravel(), weights=w.ravel(), minlength=len(uniq))

    total_w = float(w_sum.sum())
    if total_w == 0.0:
        return np.zeros(d, dtype=np.float32)

    acc = np.zeros(d, dtype=np.float64)  # small, numerically stable accumulator
    for start in range(0, len(uniq), chunk_size):
        stop = start + chunk_size
        # Only (<= chunk_size, d) rows are gathered at once, which bounds memory.
        rows = M[uniq[start:stop]].astype(np.float64)
        acc += (rows * w_sum[start:stop, None]).sum(axis=0)

    return (acc / total_w).astype(np.float32)


# Embedding generation

class LinkEmbedder:
    """Encode unique link strings once with a SentenceTransformer model.

    Also offers static helpers to save an embedding matrix as float32 ``.npy``
    and to load one back with every row L2-normalized.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2", device: Optional[str] = None):
        """Load ``model_name`` on ``device``; when None a device is auto-selected."""
        if device is None:
            device = self._default_device()
        self.model = SentenceTransformer(model_name, device=device)

    @staticmethod
    def _default_device() -> str:
        """Return the torch device used when the caller does not choose one.

        "cuda:1" (the historical default) is kept only when at least two GPUs
        are visible. On single-GPU hosts "cuda:1" does not exist and model
        loading would fail, so "cuda:0" is used there. "cpu" is used when
        CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return "cpu"
        return "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"

    def encode(self, strings: Iterable[str], batch_size: int = 128, normalize: bool = True) -> np.ndarray:
        """Embed ``strings`` and return a float32 array with one row per input.

        ``normalize`` is forwarded as ``normalize_embeddings`` to the model.
        """
        lst = list(strings)
        vecs = self.model.encode(
            lst,
            batch_size=batch_size,
            show_progress_bar=True,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
        )
        return np.asarray(vecs, dtype=np.float32)

    @staticmethod
    def save(path: Path, arr: np.ndarray) -> None:
        """Write ``arr`` as float32 ``.npy`` at ``path``, creating parent directories."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        np.save(path, arr.astype(np.float32))

    @staticmethod
    def load(path: Path) -> np.ndarray:
        """Load a ``.npy`` matrix as float32 and L2-normalize each row."""
        # NOTE: astype() copies the memory-mapped data into RAM.
        arr = np.load(Path(path), mmap_mode="r").astype(np.float32)
        return l2_rows(arr)  # make sure rows are unit norm


# TF-IDF weighting

def build_tfidf_weights(
    groups_to_links: pd.Series,
    sublinear_tf: bool = True,
) -> Dict:
    """
    Assign a TF-IDF weight to every (group, link) pair.

    Each group (an entry of ``groups_to_links``) is treated as a document whose
    terms are its links.

    - TF: occurrences of the link inside the group; with ``sublinear_tf`` this
      becomes ``1 + log(tf)``.
    - IDF: ``log((N + 1) / (df + 1)) + 1``, where N is the number of groups
      and df is the number of groups containing the link at least once.

    Returns a ``defaultdict(dict)`` laid out as ``weights[group][link] = tf * idf``.
    Groups with no links get no entry.
    """
    n_groups = len(groups_to_links)

    # Document frequency: each group counts a given link at most once.
    doc_freq = Counter(link for links in groups_to_links for link in set(links))
    idf = {
        link: math.log((n_groups + 1) / (df + 1)) + 1.0
        for link, df in doc_freq.items()
    }

    weights: Dict = defaultdict(dict)
    for group, links in groups_to_links.items():
        # Repeated links inside a group raise their term frequency.
        for link, tf in Counter(links).items():
            if sublinear_tf and tf > 0:
                tf_val = 1.0 + math.log(tf)
            else:
                tf_val = float(tf)
            weights[group][link] = float(tf_val * idf.get(link, 1.0))

    return weights


# Group pooling

class GroupEmbedder:
    """Pool link embeddings into a single vector per group."""

    def __init__(self, unique_links: List[str], embed_matrix: np.ndarray):
        """Index ``unique_links`` by position; rows of ``embed_matrix`` are L2-normalized."""
        self.unique_links = unique_links
        self.link_to_idx = {link: pos for pos, link in enumerate(unique_links)}
        self.embed_matrix = l2_rows(embed_matrix)

    def build(self, groups_to_links: pd.Series, weights: Optional[Dict] = None) -> Dict:
        """Return ``{group: pooled_vector}`` for every group with at least one known link.

        Links missing from the vocabulary are skipped. When ``weights`` is
        given, each occurrence is weighted by ``weights[group][link]``,
        defaulting to 1.0. Keys keep the type of the Series index, e.g. int
        centuries or str persons.
        """
        pooled = {}
        lookup = self.link_to_idx
        for group, links in groups_to_links.items():
            group_weights = None if weights is None else weights.get(group, {})
            rows = []
            row_weights = None if weights is None else []
            for link in links:
                pos = lookup.get(link)
                if pos is None:
                    continue
                rows.append(pos)
                if row_weights is not None:
                    row_weights.append(float(group_weights.get(link, 1.0)))
            if rows:
                pooled[group] = mean_pool(self.embed_matrix, rows, row_weights)
        return pooled


def build_century_embeddings(
    links_by_century: pd.Series,
    field_of_activity: str,
    vectors_path: Optional[Path] = None,
    sublinear_tf: bool = False,
    device: Optional[str] = None
):
    """
    Build century embeddings from the links by century data.

    When ``vectors_path`` points to an existing file whose row count matches
    the vocabulary, the cached vectors are reused. Otherwise the vectors are
    recomputed and, if ``vectors_path`` is set, written there.

    Parameters
    ----------
    links_by_century : pd.Series
        Series with links grouped by century
    field_of_activity : str
        Name of the field (used in progress messages)
    vectors_path : Optional[Path]
        Optional path to save/load vectors
    sublinear_tf : bool
        Whether to use sublinear TF in TF-IDF weighting
    device : Optional[str]
        Torch device string (e.g., "cuda:1" or "cpu")

    Returns
    -------
    century_embeddings : Dict[int, np.ndarray]
        Dictionary of century embeddings
    E_cent : np.ndarray
        Embedding matrix for all unique links
    tfidf_cent_weights : Dict
        TF-IDF weights for centuries
    """
    from .corpus import LinkCorpusBuilder

    # Build vocab
    uniq_cent, _ = LinkCorpusBuilder.build_vocab(links_by_century)

    # Try the on-disk cache first. Any failure falls back to recomputing,
    # but the reason is reported instead of being silently discarded.
    E_cent = None
    if vectors_path is not None and Path(vectors_path).exists():
        try:
            E_cent = LinkEmbedder.load(vectors_path)
            if E_cent.shape[0] != len(uniq_cent):
                raise ValueError("Vector file does not match vocab.")
        except Exception as exc:  # best-effort cache: any problem means recompute
            print(f"Ignoring cached century vectors at {vectors_path}: {exc}")
            E_cent = None

    if E_cent is None:
        print(f"Computing vectors for century links of {field_of_activity}...")
        embedder = LinkEmbedder(device=device)
        E_cent = embedder.encode(uniq_cent, batch_size=128, normalize=True)
        if vectors_path is not None:
            LinkEmbedder.save(vectors_path, E_cent)
            print(f"Saved century vectors to {vectors_path}")

    # TF-IDF weights for centuries
    tfidf_cent_weights = build_tfidf_weights(links_by_century, sublinear_tf=sublinear_tf)

    cent_pool = GroupEmbedder(uniq_cent, E_cent)
    century_embeddings: Dict[int, np.ndarray] = cent_pool.build(
        links_by_century,
        weights=tfidf_cent_weights
    )

    return century_embeddings, E_cent, tfidf_cent_weights


def build_person_embeddings(
    links_by_person: pd.Series,
    field_of_activity: str,
    vectors_path: Optional[Path] = None,
    sublinear_tf: bool = False,
    device: Optional[str] = None
):
    """
    Build person embeddings from the links by person data.

    When ``vectors_path`` points to an existing file whose row count matches
    the vocabulary, the cached vectors are reused. Otherwise the vectors are
    recomputed and, if ``vectors_path`` is set, written there.

    Parameters
    ----------
    links_by_person : pd.Series
        Series with links grouped by person
    field_of_activity : str
        Name of the field (used in progress messages)
    vectors_path : Optional[Path]
        Optional path to save/load vectors
    sublinear_tf : bool
        Whether to use sublinear TF in TF-IDF weighting
    device : Optional[str]
        Torch device string (e.g., "cuda:1" or "cpu")

    Returns
    -------
    person_embeddings : Dict[str, np.ndarray]
        Dictionary of person embeddings
    E_person : np.ndarray
        Embedding matrix for all unique links
    tfidf_person_weights : Dict
        TF-IDF weights for persons
    """
    from .corpus import LinkCorpusBuilder

    # Build vocab
    uniq_person, _ = LinkCorpusBuilder.build_vocab(links_by_person)

    # Try the on-disk cache first. Any failure falls back to recomputing,
    # but the reason is reported instead of being silently discarded.
    E_person = None
    if vectors_path is not None and Path(vectors_path).exists():
        try:
            E_person = LinkEmbedder.load(vectors_path)
            if E_person.shape[0] != len(uniq_person):
                raise ValueError("Vector file does not match vocab.")
        except Exception as exc:  # best-effort cache: any problem means recompute
            print(f"Ignoring cached person vectors at {vectors_path}: {exc}")
            E_person = None

    if E_person is None:
        print(f"Computing vectors for person links of {field_of_activity}...")
        embedder = LinkEmbedder(device=device)
        E_person = embedder.encode(uniq_person, batch_size=128, normalize=True)
        if vectors_path is not None:
            LinkEmbedder.save(vectors_path, E_person)
            print(f"Saved person vectors to {vectors_path}")

    # TF-IDF weights for persons
    tfidf_person_weights = build_tfidf_weights(links_by_person, sublinear_tf=sublinear_tf)

    person_pool = GroupEmbedder(uniq_person, E_person)
    person_embeddings: Dict[str, np.ndarray] = person_pool.build(
        links_by_person,
        weights=tfidf_person_weights
    )

    return person_embeddings, E_person, tfidf_person_weights


def compute_century_dispersion_from_person_embeddings(
    parquet_path: Path,
    person_embeddings: Dict[str, np.ndarray],
) -> pd.DataFrame:
    """
    For each century, compute dispersion of person embeddings born in that century.

    Person vectors are L2-normalized before any statistic is computed. Pooled
    embeddings are weighted means of unit vectors, so their norm is generally
    below 1, while the cosine and spherical-variance formulas assume unit
    vectors. Zero vectors carry no direction and are skipped. Rows with a
    missing ``century_of_birth`` are ignored.

    Returns a DataFrame (always with these columns, possibly empty) with:
      - century
      - n_persons: number of persons with a usable (non-zero) embedding
      - mean_cosine_radius: mean cosine distance (1 - cos) to mean direction
      - std_cosine_radius: sample std (ddof=1) of those distances; 0.0 for one person
      - spherical_variance: 1 - mean resultant length (classic measure on the sphere)
    """
    from .corpus import LinkCorpusBuilder

    columns = [
        "century",
        "n_persons",
        "mean_cosine_radius",
        "std_cosine_radius",
        "spherical_variance",
    ]

    builder = LinkCorpusBuilder(parquet_path)
    builder.load()
    # A missing century cannot be assigned, and astype(int) would raise on NaN.
    df = builder.df.dropna(subset=["century_of_birth"]).copy()

    df["century_of_birth"] = df["century_of_birth"].astype(int)
    df["Source"] = df["Source"].astype(str)

    persons_by_century = (
        df.groupby("century_of_birth")["Source"]
        .unique()
        .to_dict()
    )

    rows = []

    for century, persons in sorted(persons_by_century.items()):
        candidates = (person_embeddings.get(p) for p in persons)
        vecs = [v for v in candidates if v is not None]
        if not vecs:
            continue

        X = np.vstack(vecs).astype(np.float64)

        # 1. Project onto the unit sphere (drop direction-less zero vectors).
        norms = np.linalg.norm(X, axis=1)
        keep = norms > 1e-12
        if not keep.any():
            continue
        X = X[keep] / norms[keep, None]
        n = X.shape[0]

        # 2. Resultant and mean direction
        R_vec = X.sum(axis=0)
        R = float(np.linalg.norm(R_vec))
        # Vectors that cancel out perfectly have no mean direction.
        mu = R_vec / R if R > 1e-9 else np.zeros_like(R_vec)

        # 3. Individual cosine distances (valid now that rows and mu are unit-length)
        cos_dists = 1.0 - X @ mu

        # 4. Metrics
        rows.append(dict(
            century=century,
            n_persons=n,
            mean_cosine_radius=float(cos_dists.mean()),
            std_cosine_radius=float(cos_dists.std(ddof=1)) if n > 1 else 0.0,
            spherical_variance=1.0 - (R / n),
        ))

    if not rows:
        # Without explicit columns, sort_values("century") would raise KeyError.
        return pd.DataFrame(columns=columns)

    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values("century")
        .reset_index(drop=True)
    )
